import cv2 as cv
import numpy as np
import torch
import torchvision
from kan import KAN
import matplotlib.pyplot as plt
import os

def process_data(path):
    """Load a folder-per-class image dataset into flat NumPy arrays.

    ``path`` is expected to contain one sub-directory per class. Class
    indices are assigned by enumerating the sub-directory names in sorted
    order, so a given folder name maps to the same label in every split
    (e.g. folders named ``0``..``9`` map to labels 0..9). Every file inside a
    class folder is read as a grayscale image and flattened to 1-D.

    Args:
        path: Root directory of one split (e.g. the TRAIN or TEST folder).

    Returns:
        ``(data, label)``: ``data`` stacks the flattened images, one row per
        image (so all images must share the same size); ``label`` is an
        int64 array of class indices aligned row-for-row with ``data``.

    Raises:
        ValueError: If a file inside a class folder cannot be decoded as an
            image.
    """
    data = []
    label = []
    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    # Sorting makes the folder -> label mapping deterministic and identical
    # across splits. Plain files at the top level (README, .DS_Store, ...) are
    # not classes and are skipped.
    class_folders = sorted(
        name for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    )
    for idx, folder in enumerate(class_folders):
        folder_path = os.path.join(path, folder)
        # Sorted so the sample order is reproducible between runs.
        for file_name in sorted(os.listdir(folder_path)):
            file_path = os.path.join(folder_path, file_name)
            img = cv.imread(file_path, cv.IMREAD_GRAYSCALE)
            # imread signals failure by returning None rather than raising.
            if img is None:
                raise ValueError(f"Could not read image: {file_path}")
            data.append(img.reshape(-1))
            label.append(idx)
    return np.array(data), np.array(label, dtype=np.int64)

TRAIN_DIR = "../MNIST/transformed/TRAIN/"
TEST_DIR = "../MNIST/transformed/TEST/"

# Flattened images and their integer class indices for each split.
train_data, train_label = process_data(TRAIN_DIR)
test_data, test_label = process_data(TEST_DIR)

print(f"Number of training samples: {len(train_data)}")
print(f"Number of test samples: {len(test_data)}")

# Everything runs on the CPU. To use a GPU, select it here instead, e.g.
# torch.device("cuda:0" if torch.cuda.is_available() else "cpu").
device = torch.device("cpu")
print(f"Using {device} device")

# Inputs are converted to float32 and labels to int64 (the integer class
# targets a cross-entropy loss expects), and every tensor is placed on
# `device`.
dataset = {
    "train_input": torch.from_numpy(train_data).to(device, torch.float32),
    "train_label": torch.from_numpy(train_label).to(device, torch.int64),
    "test_input": torch.from_numpy(test_data).to(device, torch.float32),
    "test_label": torch.from_numpy(test_label).to(device, torch.int64),
}

def create_kan(width=(784, 10, 10), grid=4, k=2):
    """Build the KAN classifier used for both training and reloading.

    Constructing the model in one place guarantees that the instance rebuilt
    before ``load_state_dict`` has the same architecture as the one whose
    weights were saved.

    Args:
        width: Layer sizes, input layer first. The default 784 inputs equal
            28*28, the flattened MNIST image size; the 10 outputs presumably
            correspond to the 10 digit classes.
        grid: Passed through to ``KAN``'s ``grid`` argument.
        k: Passed through to ``KAN``'s ``k`` argument.

    Returns:
        A freshly initialised ``KAN`` instance.
    """
    # A tuple default avoids a shared mutable default; KAN receives a list,
    # as in the original call.
    return KAN(width=list(width), grid=grid, k=k)


model = create_kan()

def test_acc():
    """Return the accuracy of the global ``model`` on the test split.

    Each sample's prediction is the index of its largest output. The result
    is the fraction of predictions equal to ``dataset["test_label"]``, as a
    0-d float tensor (use ``.item()`` for a Python float). It runs under
    ``torch.no_grad()``, so no autograd graph is recorded.
    """
    inputs = dataset["test_input"]
    targets = dataset["test_label"]
    with torch.no_grad():
        scores = model(inputs)
        hits = scores.argmax(dim=1).eq(targets)
        return hits.float().mean()

def train_acc():
    """Return the accuracy of the global ``model`` on the training split.

    Each sample's prediction is the index of its largest output. The result
    is the fraction of predictions equal to ``dataset["train_label"]``, as a
    0-d float tensor. The dataset tensors are used on whatever device they
    already live on, matching ``test_acc``. This avoids copying the full
    training set to the CPU on every call when ``device`` is not the CPU.
    """
    with torch.no_grad():
        predictions = torch.argmax(model(dataset["train_input"]), dim=1)
        correct = (predictions == dataset["train_label"]).float()
        accuracy = correct.mean()
    return accuracy
 
# Train the model: `steps` optimizer updates with LBFGS on mini-batches of
# `batch` samples, minimising cross-entropy. The two metric callables are
# evaluated during training, and their histories are read back below as
# results["train_acc"] and results["test_acc"].
results = model.fit(
    dataset,
    opt="LBFGS",
    steps=1000,
    batch=512,
    loss_fn=torch.nn.CrossEntropyLoss(),
    metrics=(train_acc, test_acc),
)

# Round-trip the trained weights through disk, then evaluate a freshly built
# model that has only the checkpoint loaded. test_acc reads the global
# `model`, so after the rebinding below it measures the reloaded copy.
torch.save(model.state_dict(), "kan.pth")

del model
model = create_kan()
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load("kan.pth", weights_only=True))

acc = test_acc()
print(f"Test accuracy: {acc.item() * 100:.2f}%")
total_params = sum(p.numel() for p in model.parameters())
print(f"total number of parameters: {total_params}")

# fit records one metric value per optimisation step, not per pass over the
# data. With batch=512, a single step covers only part of the training set,
# so the x-axis is labelled in steps.
plt.plot(results["train_acc"], label="train")
plt.plot(results["test_acc"], label="test")
plt.legend()
plt.title("Training and Testing Accuracy")
plt.xlabel("Step")
plt.ylabel("Accuracy")
plt.savefig("accuracy_plot.png")
plt.close()